Generative Modeling¶

Lecture Date: 2025-03-14¶

Four models are detailed in this notebook: prepare the DataLoaders first, then jump to any of the model sections and run it. There is some overlap between the implementations, but each one makes unique changes to the architecture and/or training regime for that particular approach.

  1. Prepare MNIST DataLoaders (do this first, then jump to a model) (MNIST Data Prep)
  2. Generative Adversarial Network (GAN)
  3. Auxiliary Classifier Generative Adversarial Network (AC-GAN)
  4. Denoising Diffusion Probabilistic Model (DDPM)
  5. Conditioned Denoising Diffusion Probabilistic Model (CDDPM)
In [1]:
import numpy as np
import torch
import torchvision
import torchmetrics
import lightning.pytorch as pl
from torchinfo import summary
from torchview import draw_graph
import matplotlib.pyplot as plt
import pandas as pd
In [2]:
# Minor tensor-core speedup
torch.set_float32_matmul_precision('medium')

MNIST Data Prep¶

Works with all four models below¶

In [3]:
class GenerativeMNISTDataModule(pl.LightningDataModule):
    def __init__(self,
                 batch_size = 128,
                 num_workers = 4,
                 location = "~/datasets",
                 **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.location = location
        self.input_shape = (1,28,28)
        self.output_shape = (1,28,28)
        self.data_train = None
        self.data_test = None
        self.data_predict = None
        
    def setup(self, stage: str):
        if stage == 'fit' and not self.data_train:
            training_dataset = torchvision.datasets.MNIST(root=self.location,download=True, train=True)
            x_train = np.expand_dims(training_dataset.data,1)
            x_train = (x_train / 127.5) - 1.0
            y_train = np.array(training_dataset.targets)
            self.data_train = list(zip(torch.Tensor(x_train).to(torch.float32),
                                       torch.Tensor(y_train).to(torch.long)))
        if (stage == 'test' or \
            stage == 'predict') and \
            not(self.data_test and self.data_predict):
            testing_dataset = torchvision.datasets.MNIST(root=self.location,download=True, train=False)
            x_test = np.expand_dims(testing_dataset.data,1)
            x_test = (x_test / 127.5) - 1.0
            y_test = np.array(testing_dataset.targets)
            self.data_test = list(zip(torch.Tensor(x_test).to(torch.float32),
                                      torch.Tensor(y_test).to(torch.long)))
            self.data_predict = list(zip(torch.zeros(x_test[:80].shape).to(torch.float32),
                                         torch.repeat_interleave(torch.arange(10),8).to(torch.long)))
            
    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.data_train,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers,
                                           shuffle=True)
    # Doesn't currently exist
    # def val_dataloader(self):
    #     return torch.utils.data.DataLoader(self.data_val,
    #                                        batch_size=self.batch_size,
    #                                        num_workers=self.num_workers,
    #                                        shuffle=False)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.data_test,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers,
                                           shuffle=False)

    def predict_dataloader(self):
        return torch.utils.data.DataLoader(self.data_predict,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers,
                                           shuffle=False)
In [4]:
data_module = GenerativeMNISTDataModule()
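
The DataModule above maps raw uint8 pixels into [-1, 1] to match the generators' Tanh output range, and the models' `predict_step` later undoes this with `(x + 1) / 2`. A quick sanity check of the two mappings (plain NumPy with toy values, no dataset download needed):

```python
import numpy as np

# Raw MNIST pixel intensities are uint8 in [0, 255].
pixels = np.array([0.0, 127.5, 255.0])

# Forward mapping used in setup(): [0, 255] -> [-1, 1]
scaled = (pixels / 127.5) - 1.0

# Inverse mapping used in predict_step(): [-1, 1] -> [0, 1] for plotting
recovered = (scaled + 1.0) / 2.0

print(scaled)     # [-1.  0.  1.]
print(recovered)  # [0.  0.5 1. ]
```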

GAN¶

Adapted from: https://lightning.ai/docs/pytorch/2.1.0/notebooks/lightning_examples/basic-gan.html

The model below implements a Generative Adversarial Network (GAN) that learns to generate MNIST digits. It is not significantly prone to mode collapse, which is largely avoided here by using a relatively large latent space and residually-connected bottleneck convolutional networks for both the generator and the discriminator. (The original code linked above would rarely converge and typically exhibited extreme mode collapse because it used only standard, non-residual layers.)

In [5]:
class ResidualConv2dLayer(pl.LightningModule):
    def __init__(self,
                 size,
                 **kwargs):
        super().__init__(**kwargs)
        self.padding = torch.nn.ZeroPad2d((1,1,1,1))
        self.conv_layer = torch.nn.Conv2d(size,
                                          size,
                                          kernel_size=(3,3),
                                          stride=(1,1))
        self.norm = torch.nn.BatchNorm2d(size)
        self.activation = torch.nn.LeakyReLU(0.2, inplace=True)
    def forward(self, x):
        y = x
        y = self.padding(y)
        y = self.conv_layer(y)
        y = self.norm(y)
        y = self.activation(y)
        y = x + y
        return y
In [6]:
class Generator(pl.LightningModule):
    def __init__(self,
                 latent_dim,
                 n_channels = 1,
                 **kwargs):
        super().__init__(**kwargs)
        self.model = torch.nn.Sequential(*(
            [torch.nn.Linear(latent_dim,latent_dim*4),
             torch.nn.Unflatten(1,(latent_dim,2,2))] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] +
            [torch.nn.Upsample(scale_factor=2)] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] +
            [torch.nn.Upsample(scale_factor=2),
             torch.nn.Conv2d(latent_dim,latent_dim,(2,2))] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] +
            [torch.nn.Upsample(scale_factor=2)] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] +
            [torch.nn.Upsample(scale_factor=2)] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] +
            [torch.nn.Conv2d(latent_dim,n_channels,1),
             torch.nn.Tanh()]        
        ))
    def forward(self, x):
        y = x
        y = self.model(y)
        return y
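
The generator grows a 2×2 latent grid to the 28×28 MNIST resolution: each `Upsample` doubles the side length, and the single unpadded stride-1 2×2 convolution trims 8 down to 7 so that the remaining two doublings land exactly on 28. Tracing the side lengths (a plain-Python sketch of the arithmetic, not part of the model):

```python
def upsample(n):
    # torch.nn.Upsample(scale_factor=2) doubles each spatial dimension
    return n * 2

def conv2x2_valid(n):
    # stride-1 2x2 conv with no padding shrinks each side by 1
    return n - 1

side = 2                    # after Unflatten: (latent_dim, 2, 2)
side = upsample(side)       # 4
side = upsample(side)       # 8
side = conv2x2_valid(side)  # 7
side = upsample(side)       # 14
side = upsample(side)       # 28
print(side)  # 28
```

This matches the output shapes in the torchinfo summary below.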
In [7]:
temp_g = Generator(128)
summary(temp_g,
        input_size=(5,128))
Out[7]:
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Generator                                [5, 1, 28, 28]            --
├─Sequential: 1-1                        [5, 1, 28, 28]            --
│    └─Linear: 2-1                       [5, 512]                  66,048
│    └─Unflatten: 2-2                    [5, 128, 2, 2]            --
│    └─ResidualConv2dLayer: 2-3          [5, 128, 2, 2]            --
│    │    └─ZeroPad2d: 3-1               [5, 128, 4, 4]            --
│    │    └─Conv2d: 3-2                  [5, 128, 2, 2]            147,584
│    │    └─BatchNorm2d: 3-3             [5, 128, 2, 2]            256
│    │    └─LeakyReLU: 3-4               [5, 128, 2, 2]            --
│    └─ResidualConv2dLayer: 2-4          [5, 128, 2, 2]            --
│    │    └─ZeroPad2d: 3-5               [5, 128, 4, 4]            --
│    │    └─Conv2d: 3-6                  [5, 128, 2, 2]            147,584
│    │    └─BatchNorm2d: 3-7             [5, 128, 2, 2]            256
│    │    └─LeakyReLU: 3-8               [5, 128, 2, 2]            --
│    └─Upsample: 2-5                     [5, 128, 4, 4]            --
│    └─ResidualConv2dLayer: 2-6          [5, 128, 4, 4]            --
│    │    └─ZeroPad2d: 3-9               [5, 128, 6, 6]            --
│    │    └─Conv2d: 3-10                 [5, 128, 4, 4]            147,584
│    │    └─BatchNorm2d: 3-11            [5, 128, 4, 4]            256
│    │    └─LeakyReLU: 3-12              [5, 128, 4, 4]            --
│    └─ResidualConv2dLayer: 2-7          [5, 128, 4, 4]            --
│    │    └─ZeroPad2d: 3-13              [5, 128, 6, 6]            --
│    │    └─Conv2d: 3-14                 [5, 128, 4, 4]            147,584
│    │    └─BatchNorm2d: 3-15            [5, 128, 4, 4]            256
│    │    └─LeakyReLU: 3-16              [5, 128, 4, 4]            --
│    └─Upsample: 2-8                     [5, 128, 8, 8]            --
│    └─Conv2d: 2-9                       [5, 128, 7, 7]            65,664
│    └─ResidualConv2dLayer: 2-10         [5, 128, 7, 7]            --
│    │    └─ZeroPad2d: 3-17              [5, 128, 9, 9]            --
│    │    └─Conv2d: 3-18                 [5, 128, 7, 7]            147,584
│    │    └─BatchNorm2d: 3-19            [5, 128, 7, 7]            256
│    │    └─LeakyReLU: 3-20              [5, 128, 7, 7]            --
│    └─ResidualConv2dLayer: 2-11         [5, 128, 7, 7]            --
│    │    └─ZeroPad2d: 3-21              [5, 128, 9, 9]            --
│    │    └─Conv2d: 3-22                 [5, 128, 7, 7]            147,584
│    │    └─BatchNorm2d: 3-23            [5, 128, 7, 7]            256
│    │    └─LeakyReLU: 3-24              [5, 128, 7, 7]            --
│    └─Upsample: 2-12                    [5, 128, 14, 14]          --
│    └─ResidualConv2dLayer: 2-13         [5, 128, 14, 14]          --
│    │    └─ZeroPad2d: 3-25              [5, 128, 16, 16]          --
│    │    └─Conv2d: 3-26                 [5, 128, 14, 14]          147,584
│    │    └─BatchNorm2d: 3-27            [5, 128, 14, 14]          256
│    │    └─LeakyReLU: 3-28              [5, 128, 14, 14]          --
│    └─ResidualConv2dLayer: 2-14         [5, 128, 14, 14]          --
│    │    └─ZeroPad2d: 3-29              [5, 128, 16, 16]          --
│    │    └─Conv2d: 3-30                 [5, 128, 14, 14]          147,584
│    │    └─BatchNorm2d: 3-31            [5, 128, 14, 14]          256
│    │    └─LeakyReLU: 3-32              [5, 128, 14, 14]          --
│    └─Upsample: 2-15                    [5, 128, 28, 28]          --
│    └─ResidualConv2dLayer: 2-16         [5, 128, 28, 28]          --
│    │    └─ZeroPad2d: 3-33              [5, 128, 30, 30]          --
│    │    └─Conv2d: 3-34                 [5, 128, 28, 28]          147,584
│    │    └─BatchNorm2d: 3-35            [5, 128, 28, 28]          256
│    │    └─LeakyReLU: 3-36              [5, 128, 28, 28]          --
│    └─ResidualConv2dLayer: 2-17         [5, 128, 28, 28]          --
│    │    └─ZeroPad2d: 3-37              [5, 128, 30, 30]          --
│    │    └─Conv2d: 3-38                 [5, 128, 28, 28]          147,584
│    │    └─BatchNorm2d: 3-39            [5, 128, 28, 28]          256
│    │    └─LeakyReLU: 3-40              [5, 128, 28, 28]          --
│    └─Conv2d: 2-18                      [5, 1, 28, 28]            129
│    └─Tanh: 2-19                        [5, 1, 28, 28]            --
==========================================================================================
Total params: 1,610,241
Trainable params: 1,610,241
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 1.57
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 21.79
Params size (MB): 6.44
Estimated Total Size (MB): 28.23
==========================================================================================
In [8]:
# model_graph = draw_graph(temp_g,
#                          input_size=(5,128),
#                          depth=3)
# model_graph.visual_graph

image.png

In [9]:
class Discriminator(pl.LightningModule):
    def __init__(self,
                 latent_dim,
                 n_channels=1,
                 **kwargs):
        super().__init__(**kwargs)

        self.model = torch.nn.Sequential(*(
            [torch.nn.Conv2d(n_channels,latent_dim,1)]+
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] + # 28
            [torch.nn.AvgPool2d((2,2))] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] + # 14
            [torch.nn.AvgPool2d((2,2))] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] + # 7
            [torch.nn.ZeroPad2d((1,1,1,1)), torch.nn.AvgPool2d((2,2))] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] + # 4
            [torch.nn.AvgPool2d((2,2))] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] + # 2
            [torch.nn.AvgPool2d((2,2)),
             torch.nn.Flatten(),
             torch.nn.Linear(latent_dim,1),
             torch.nn.Sigmoid()]
        ))

    def forward(self, x):
        y = x
        y = self.model(y)
        return y
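
The discriminator reverses that path, halving the side with each `AvgPool2d`; the odd 7×7 map is zero-padded to 9×9 so that the 2×2 pool (which floors) yields 4×4, and two more pools reach 1×1 before the final `Linear`. The size arithmetic, sketched in plain Python:

```python
def avgpool2(n):
    # torch.nn.AvgPool2d((2, 2)): kernel 2, stride 2, floor division
    return n // 2

side = 28
side = avgpool2(side)      # 14
side = avgpool2(side)      # 7
side = avgpool2(side + 2)  # ZeroPad2d((1,1,1,1)) pads 7 -> 9; pool floors 9 -> 4
side = avgpool2(side)      # 2
side = avgpool2(side)      # 1
print(side)  # 1
```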
In [10]:
temp_d = Discriminator(128)
summary(temp_d,input_size=(data_module.batch_size,)+data_module.input_shape)
Out[10]:
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Discriminator                            [128, 1]                  --
├─Sequential: 1-1                        [128, 1]                  --
│    └─Conv2d: 2-1                       [128, 128, 28, 28]        256
│    └─ResidualConv2dLayer: 2-2          [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-1               [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-2                  [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-3             [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-4               [128, 128, 28, 28]        --
│    └─ResidualConv2dLayer: 2-3          [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-5               [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-6                  [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-7             [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-8               [128, 128, 28, 28]        --
│    └─AvgPool2d: 2-4                    [128, 128, 14, 14]        --
│    └─ResidualConv2dLayer: 2-5          [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-9               [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-10                 [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-11            [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-12              [128, 128, 14, 14]        --
│    └─ResidualConv2dLayer: 2-6          [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-13              [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-14                 [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-15            [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-16              [128, 128, 14, 14]        --
│    └─AvgPool2d: 2-7                    [128, 128, 7, 7]          --
│    └─ResidualConv2dLayer: 2-8          [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-17              [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-18                 [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-19            [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-20              [128, 128, 7, 7]          --
│    └─ResidualConv2dLayer: 2-9          [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-21              [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-22                 [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-23            [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-24              [128, 128, 7, 7]          --
│    └─ZeroPad2d: 2-10                   [128, 128, 9, 9]          --
│    └─AvgPool2d: 2-11                   [128, 128, 4, 4]          --
│    └─ResidualConv2dLayer: 2-12         [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-25              [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-26                 [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-27            [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-28              [128, 128, 4, 4]          --
│    └─ResidualConv2dLayer: 2-13         [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-29              [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-30                 [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-31            [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-32              [128, 128, 4, 4]          --
│    └─AvgPool2d: 2-14                   [128, 128, 2, 2]          --
│    └─ResidualConv2dLayer: 2-15         [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-33              [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-34                 [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-35            [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-36              [128, 128, 2, 2]          --
│    └─ResidualConv2dLayer: 2-16         [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-37              [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-38                 [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-39            [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-40              [128, 128, 2, 2]          --
│    └─AvgPool2d: 2-17                   [128, 128, 1, 1]          --
│    └─Flatten: 2-18                     [128, 128]                --
│    └─Linear: 2-19                      [128, 1]                  129
│    └─Sigmoid: 2-20                     [128, 1]                  --
==========================================================================================
Total params: 1,478,785
Trainable params: 1,478,785
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 39.66
==========================================================================================
Input size (MB): 0.40
Forward/backward pass size (MB): 652.74
Params size (MB): 5.92
Estimated Total Size (MB): 659.06
==========================================================================================
In [11]:
# model_graph = draw_graph(temp_d,
#                          input_size=(data_module.batch_size,)+data_module.input_shape,
#                          depth=3)
# model_graph.visual_graph

image.png

In [12]:
class GAN(pl.LightningModule):
    def __init__(
        self,
        latent_dim: int = 128,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = 256,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.save_hyperparameters()
        self.automatic_optimization = False

        # networks
        self.generator = Generator(latent_dim=self.hparams.latent_dim)
        self.discriminator = Discriminator(latent_dim=self.hparams.latent_dim)

    def forward(self, z):
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        return torch.nn.functional.binary_cross_entropy(y_hat, y)

    def predict_step(self, batch):
        zeros, y_true = batch
        z = torch.randn(zeros.shape[0], self.hparams.latent_dim).type_as(zeros)
        imgs = (self(z).permute((0,2,3,1))+1.0)/2.0
        return imgs

    def training_step(self, batch):
        imgs, _ = batch

        optimizer_g, optimizer_d = self.optimizers()

        # sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)

        # train generator
        # generate images
        self.toggle_optimizer(optimizer_g)
        self.generated_imgs = self(z)

        # target labels for the generator: fakes should be scored as real
        # (softened from 1.0 to 0.5 as a label-smoothing tweak)
        # put on the same device as the batch, since this tensor is created here
        valid = torch.ones(imgs.size(0), 1)*0.5
        valid = valid.type_as(imgs)

        # adversarial loss is binary cross-entropy
        g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

        # train discriminator
        # Measure discriminator's ability to classify real from generated samples
        self.toggle_optimizer(optimizer_d)

        # how well can it label as real?
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

        # how well can it label as fake?
        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs)

        fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)

        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []
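
`adversarial_loss` above is ordinary binary cross-entropy; the generator and discriminator simply supply opposite targets for the same fake images. Computed by hand for a toy batch (NumPy mirroring `F.binary_cross_entropy`'s mean reduction; the scores are made-up illustration values):

```python
import numpy as np

def bce(p, y):
    # mean binary cross-entropy, as in torch.nn.functional.binary_cross_entropy
    return float(np.mean(-(y * np.log(p) + (1 - y) * np.log(1 - p))))

d_out_fake = np.array([0.3, 0.4])  # discriminator scores on generated images
d_out_real = np.array([0.8, 0.9])  # discriminator scores on real images

# Generator wants its fakes scored as real; the code above softens the
# target to 0.5 as a smoothing tweak.
g_loss = bce(d_out_fake, np.full(2, 0.5))

# Discriminator wants real -> 1 and fake -> 0, averaged, as in training_step.
d_loss = 0.5 * (bce(d_out_real, np.ones(2)) + bce(d_out_fake, np.zeros(2)))

print(round(g_loss, 3), round(d_loss, 3))
```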
In [13]:
model = GAN(128)
In [14]:
# model = GAN.load_from_checkpoint("logs/gan/mnist/checkpoints/epoch=9-step=9380.ckpt")

Before Training...¶

In [15]:
logger = pl.loggers.CSVLogger("logs",
                              name="gan",
                              version="mnist")
In [16]:
trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=10,
    logger=logger,
    callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=50)]
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
In [17]:
imgs = trainer.predict(model, data_module)[0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
/opt/conda/lib/python3.12/site-packages/torch/utils/data/dataloader.py:617: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(
Predicting: |          | 0/? [00:00<?, ?it/s]
In [18]:
imgs.shape
Out[18]:
torch.Size([80, 28, 28, 1])
In [19]:
fig = plt.figure(figsize=(10, 8))
for i in range(imgs.shape[0]):
    plt.subplot(10, 8, i+1)
    plt.imshow(imgs[i], cmap="gray")
    plt.axis("off")
plt.show()
(image: 10×8 grid of generator samples before training)
In [ ]:
# Train the GAN
trainer.fit(model, data_module)
In [ ]:
results = pd.read_csv("logs/gan/mnist/metrics.csv",sep=",")
In [31]:
plt.plot(results["epoch"][np.logical_not(np.isnan(results["g_loss"]))],
         results["g_loss"][np.logical_not(np.isnan(results["g_loss"]))],
         label="Loss")
plt.ylabel('Generator Loss')
plt.xlabel('Epoch')
# plt.legend(loc='lower right')
plt.show()
(image: generator loss vs. epoch)
In [32]:
plt.plot(results["epoch"][np.logical_not(np.isnan(results["d_loss"]))],
         results["d_loss"][np.logical_not(np.isnan(results["d_loss"]))],
         label="Loss")
plt.ylabel('Discriminator Loss')
plt.xlabel('Epoch')
# plt.legend(loc='lower right')
plt.show()
(image: discriminator loss vs. epoch)

After Training...¶

In [92]:
imgs = trainer.predict(model, data_module)[0]
Predicting: |          | 0/? [00:00<?, ?it/s]
In [93]:
imgs.shape
Out[93]:
torch.Size([80, 28, 28, 1])
In [31]:
fig = plt.figure(figsize=(10, 8))
for i in range(imgs.shape[0]):
    plt.subplot(10, 8, i+1)
    plt.imshow(imgs[i], cmap="gray")
    plt.axis("off")
plt.show()
(image: 10×8 grid of generator samples after training)

AC-GAN¶

Inspired by: https://lightning.ai/docs/pytorch/2.1.0/notebooks/lightning_examples/basic-gan.html

This version prevents mode collapse by using the class labels to create an Auxiliary Classifier Generative Adversarial Network (AC-GAN). The discriminator predicts 11 classes: real digits 0 through 9, plus class 10 for fakes (of any kind). The generator receives an additional integer input, 0 to 9, indicating which digit it should generate (in addition to the random noise vector). Mode collapse is largely avoided because all 10 digit classes can be requested and each must be able to fool the discriminator. Label supervision thus helps avoid mode collapse, but it requires a labeled data set.
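
Under this scheme the discriminator becomes an (n_classes + 1)-way classifier trained with cross-entropy. A sketch of how the targets would be built, with class index 10 reserved for "fake" (illustrative only, toy label values):

```python
import numpy as np

n_classes = 10
FAKE = n_classes  # extra class index standing for "generated image"

real_digit_labels = np.array([3, 7, 0, 9])  # labels of a real mini-batch
real_targets = real_digit_labels            # real images keep their digit class
fake_targets = np.full(4, FAKE)             # every generated image maps to class 10

print(real_targets, fake_targets)  # [3 7 0 9] [10 10 10 10]
```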

In [6]:
class ResidualConv2dLayer(pl.LightningModule):
    def __init__(self,
                 size,
                 **kwargs):
        super().__init__(**kwargs)
        self.padding = torch.nn.ZeroPad2d((1,1,1,1))
        self.conv_layer = torch.nn.Conv2d(size,
                                          size,
                                          kernel_size=(3,3),
                                          stride=(1,1))
        self.norm = torch.nn.BatchNorm2d(size)
        self.activation = torch.nn.LeakyReLU(0.2, inplace=True)
    def forward(self, x):
        y = x
        y = self.padding(y)
        y = self.conv_layer(y)
        y = self.norm(y)
        y = self.activation(y)
        y = x + y
        return y
In [7]:
class Generator(pl.LightningModule):
    def __init__(self,
                 latent_dim,
                 n_channels = 1,
                 n_classes = 10,
                 **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.embedding = torch.nn.Embedding(n_classes,latent_dim)
        self.model = torch.nn.Sequential(*(
            [torch.nn.Linear(latent_dim,latent_dim*4),
             torch.nn.Unflatten(1,(latent_dim,2,2))] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] +
            [torch.nn.Upsample(scale_factor=2)] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] +
            [torch.nn.Upsample(scale_factor=2),
             torch.nn.Conv2d(latent_dim,latent_dim,(2,2))] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] +
            [torch.nn.Upsample(scale_factor=2)] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] +
            [torch.nn.Upsample(scale_factor=2)] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] +
            [torch.nn.Conv2d(latent_dim,n_channels,1),
             torch.nn.Tanh()]        
        ))
    def forward(self, x, c):
        y = x * self.embedding(c)
        y = self.model(y)
        return y
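
This conditional generator injects the class by elementwise-multiplying the noise vector with a learned per-class embedding, so each digit applies its own scaling of the latent space. The lookup-and-multiply, mimicked with NumPy (toy sizes and a random stand-in for the `nn.Embedding` weight table, illustration only):

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim, n_classes = 4, 10  # tiny sizes for illustration

# Stand-in for the nn.Embedding weight matrix: one row per class.
embedding = rng.normal(size=(n_classes, latent_dim))

z = rng.normal(size=(2, latent_dim))  # batch of noise vectors
c = np.array([3, 7])                  # requested digit classes

# Row lookup then elementwise product: same as x * self.embedding(c) above.
conditioned = z * embedding[c]
print(conditioned.shape)  # (2, 4)
```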
In [8]:
temp_g = Generator(128)
summary(temp_g,
        input_size=[(5,128),(5,)],
        dtypes=[torch.float32,torch.long])
Out[8]:
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Generator                                [5, 1, 28, 28]            --
├─Embedding: 1-1                         [5, 128]                  1,280
├─Sequential: 1-2                        [5, 1, 28, 28]            --
│    └─Linear: 2-1                       [5, 512]                  66,048
│    └─Unflatten: 2-2                    [5, 128, 2, 2]            --
│    └─ResidualConv2dLayer: 2-3          [5, 128, 2, 2]            --
│    │    └─ZeroPad2d: 3-1               [5, 128, 4, 4]            --
│    │    └─Conv2d: 3-2                  [5, 128, 2, 2]            147,584
│    │    └─BatchNorm2d: 3-3             [5, 128, 2, 2]            256
│    │    └─LeakyReLU: 3-4               [5, 128, 2, 2]            --
│    └─ResidualConv2dLayer: 2-4          [5, 128, 2, 2]            --
│    │    └─ZeroPad2d: 3-5               [5, 128, 4, 4]            --
│    │    └─Conv2d: 3-6                  [5, 128, 2, 2]            147,584
│    │    └─BatchNorm2d: 3-7             [5, 128, 2, 2]            256
│    │    └─LeakyReLU: 3-8               [5, 128, 2, 2]            --
│    └─Upsample: 2-5                     [5, 128, 4, 4]            --
│    └─ResidualConv2dLayer: 2-6          [5, 128, 4, 4]            --
│    │    └─ZeroPad2d: 3-9               [5, 128, 6, 6]            --
│    │    └─Conv2d: 3-10                 [5, 128, 4, 4]            147,584
│    │    └─BatchNorm2d: 3-11            [5, 128, 4, 4]            256
│    │    └─LeakyReLU: 3-12              [5, 128, 4, 4]            --
│    └─ResidualConv2dLayer: 2-7          [5, 128, 4, 4]            --
│    │    └─ZeroPad2d: 3-13              [5, 128, 6, 6]            --
│    │    └─Conv2d: 3-14                 [5, 128, 4, 4]            147,584
│    │    └─BatchNorm2d: 3-15            [5, 128, 4, 4]            256
│    │    └─LeakyReLU: 3-16              [5, 128, 4, 4]            --
│    └─Upsample: 2-8                     [5, 128, 8, 8]            --
│    └─Conv2d: 2-9                       [5, 128, 7, 7]            65,664
│    └─ResidualConv2dLayer: 2-10         [5, 128, 7, 7]            --
│    │    └─ZeroPad2d: 3-17              [5, 128, 9, 9]            --
│    │    └─Conv2d: 3-18                 [5, 128, 7, 7]            147,584
│    │    └─BatchNorm2d: 3-19            [5, 128, 7, 7]            256
│    │    └─LeakyReLU: 3-20              [5, 128, 7, 7]            --
│    └─ResidualConv2dLayer: 2-11         [5, 128, 7, 7]            --
│    │    └─ZeroPad2d: 3-21              [5, 128, 9, 9]            --
│    │    └─Conv2d: 3-22                 [5, 128, 7, 7]            147,584
│    │    └─BatchNorm2d: 3-23            [5, 128, 7, 7]            256
│    │    └─LeakyReLU: 3-24              [5, 128, 7, 7]            --
│    └─Upsample: 2-12                    [5, 128, 14, 14]          --
│    └─ResidualConv2dLayer: 2-13         [5, 128, 14, 14]          --
│    │    └─ZeroPad2d: 3-25              [5, 128, 16, 16]          --
│    │    └─Conv2d: 3-26                 [5, 128, 14, 14]          147,584
│    │    └─BatchNorm2d: 3-27            [5, 128, 14, 14]          256
│    │    └─LeakyReLU: 3-28              [5, 128, 14, 14]          --
│    └─ResidualConv2dLayer: 2-14         [5, 128, 14, 14]          --
│    │    └─ZeroPad2d: 3-29              [5, 128, 16, 16]          --
│    │    └─Conv2d: 3-30                 [5, 128, 14, 14]          147,584
│    │    └─BatchNorm2d: 3-31            [5, 128, 14, 14]          256
│    │    └─LeakyReLU: 3-32              [5, 128, 14, 14]          --
│    └─Upsample: 2-15                    [5, 128, 28, 28]          --
│    └─ResidualConv2dLayer: 2-16         [5, 128, 28, 28]          --
│    │    └─ZeroPad2d: 3-33              [5, 128, 30, 30]          --
│    │    └─Conv2d: 3-34                 [5, 128, 28, 28]          147,584
│    │    └─BatchNorm2d: 3-35            [5, 128, 28, 28]          256
│    │    └─LeakyReLU: 3-36              [5, 128, 28, 28]          --
│    └─ResidualConv2dLayer: 2-17         [5, 128, 28, 28]          --
│    │    └─ZeroPad2d: 3-37              [5, 128, 30, 30]          --
│    │    └─Conv2d: 3-38                 [5, 128, 28, 28]          147,584
│    │    └─BatchNorm2d: 3-39            [5, 128, 28, 28]          256
│    │    └─LeakyReLU: 3-40              [5, 128, 28, 28]          --
│    └─Conv2d: 2-18                      [5, 1, 28, 28]            129
│    └─Tanh: 2-19                        [5, 1, 28, 28]            --
==========================================================================================
Total params: 1,611,521
Trainable params: 1,611,521
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 1.57
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 21.79
Params size (MB): 6.45
Estimated Total Size (MB): 28.24
==========================================================================================
In [9]:
# model_graph = draw_graph(temp_g,
#                          input_size=[(5,128),(5,)],
#                          dtypes=[torch.float32,torch.long],
#                          depth=3)
# model_graph.visual_graph

[Generator model graph (torchview) image]

In [10]:
class Discriminator(pl.LightningModule):
    def __init__(self,
                 latent_dim,
                 n_channels=1,
                 n_classes=10):
        super().__init__()

        self.model = torch.nn.Sequential(*(
            [torch.nn.Conv2d(n_channels,latent_dim,1)]+
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] + # 28
            [torch.nn.AvgPool2d((2,2))] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] + # 14
            [torch.nn.AvgPool2d((2,2))] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] + # 7
            [torch.nn.ZeroPad2d((1,1,1,1)), torch.nn.AvgPool2d((2,2))] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] + # 4
            [torch.nn.AvgPool2d((2,2))] +
            [ResidualConv2dLayer(latent_dim) for _ in range(2)] + # 2
            [torch.nn.AvgPool2d((2,2)),
             torch.nn.Flatten(),
             torch.nn.Linear(latent_dim,n_classes+1)]
        ))

    def forward(self, x):
        y = x
        y = self.model(y)
        return y
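The final `Linear(latent_dim, n_classes+1)` gives the AC-GAN discriminator 11 logits: ten digit classes plus one extra slot (index 10) reserved for "fake". A minimal NumPy sketch (the logit values are made up) of how a single output vector is read:

```python
import numpy as np

# Hypothetical logits for one sample: 10 digit classes plus a "fake" slot (index 10)
logits = np.array([0.1, 2.5, 0.3, 0.0, 0.2, 0.1, 0.4, 0.0, 0.1, 0.2, -1.0])

# Numerically stable softmax over all 11 slots
probs = np.exp(logits - logits.max())
probs /= probs.sum()

predicted_class = int(np.argmax(probs))  # 0-9 => a digit, 10 => "fake"
p_fake = float(probs[10])                # probability mass on the fake slot
```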
In [11]:
temp_d = Discriminator(128)
summary(temp_d,input_size=(data_module.batch_size,)+data_module.input_shape)
Out[11]:
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Discriminator                            [128, 11]                 --
├─Sequential: 1-1                        [128, 11]                 --
│    └─Conv2d: 2-1                       [128, 128, 28, 28]        256
│    └─ResidualConv2dLayer: 2-2          [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-1               [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-2                  [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-3             [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-4               [128, 128, 28, 28]        --
│    └─ResidualConv2dLayer: 2-3          [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-5               [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-6                  [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-7             [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-8               [128, 128, 28, 28]        --
│    └─AvgPool2d: 2-4                    [128, 128, 14, 14]        --
│    └─ResidualConv2dLayer: 2-5          [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-9               [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-10                 [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-11            [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-12              [128, 128, 14, 14]        --
│    └─ResidualConv2dLayer: 2-6          [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-13              [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-14                 [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-15            [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-16              [128, 128, 14, 14]        --
│    └─AvgPool2d: 2-7                    [128, 128, 7, 7]          --
│    └─ResidualConv2dLayer: 2-8          [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-17              [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-18                 [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-19            [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-20              [128, 128, 7, 7]          --
│    └─ResidualConv2dLayer: 2-9          [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-21              [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-22                 [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-23            [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-24              [128, 128, 7, 7]          --
│    └─ZeroPad2d: 2-10                   [128, 128, 9, 9]          --
│    └─AvgPool2d: 2-11                   [128, 128, 4, 4]          --
│    └─ResidualConv2dLayer: 2-12         [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-25              [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-26                 [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-27            [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-28              [128, 128, 4, 4]          --
│    └─ResidualConv2dLayer: 2-13         [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-29              [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-30                 [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-31            [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-32              [128, 128, 4, 4]          --
│    └─AvgPool2d: 2-14                   [128, 128, 2, 2]          --
│    └─ResidualConv2dLayer: 2-15         [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-33              [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-34                 [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-35            [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-36              [128, 128, 2, 2]          --
│    └─ResidualConv2dLayer: 2-16         [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-37              [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-38                 [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-39            [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-40              [128, 128, 2, 2]          --
│    └─AvgPool2d: 2-17                   [128, 128, 1, 1]          --
│    └─Flatten: 2-18                     [128, 128]                --
│    └─Linear: 2-19                      [128, 11]                 1,419
==========================================================================================
Total params: 1,480,075
Trainable params: 1,480,075
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 39.66
==========================================================================================
Input size (MB): 0.40
Forward/backward pass size (MB): 652.75
Params size (MB): 5.92
Estimated Total Size (MB): 659.07
==========================================================================================
In [12]:
# model_graph = draw_graph(temp_d,
#                          input_size=(data_module.batch_size,)+data_module.input_shape,
#                          depth=3)
# model_graph.visual_graph

[Discriminator model graph (torchview) image]

In [13]:
class ACGAN(pl.LightningModule):
    def __init__(
        self,
        latent_dim: int = 128,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = 256,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.save_hyperparameters()
        self.automatic_optimization = False

        # networks
        self.generator = Generator(latent_dim=self.hparams.latent_dim)
        self.discriminator = Discriminator(latent_dim=self.hparams.latent_dim)

    def forward(self, z, c):
        return self.generator(z, c)

    def adversarial_loss(self, y_hat, y):
        return torch.nn.functional.cross_entropy(y_hat, y)

    def predict_step(self, batch):
        zeros, y_true = batch
        z = torch.randn(zeros.shape[0], self.hparams.latent_dim).type_as(zeros)
        imgs = (self(z,y_true).permute((0,2,3,1))+1.0)/2.0
        return imgs

    def training_step(self, batch):
        imgs, y_true = batch

        optimizer_g, optimizer_d = self.optimizers()

        # sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)
        c = torch.randint(10,y_true.shape,dtype=torch.long)
        c = c.type_as(y_true)
        
        # train generator
        # generate images
        self.toggle_optimizer(optimizer_g)
        self.generated_imgs = self(z,c)

        # adversarial loss is categorical cross-entropy
        g_loss = self.adversarial_loss(self.discriminator(self(z,c)), c)
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

        # train discriminator
        # Measure discriminator's ability to classify real from generated samples
        self.toggle_optimizer(optimizer_d)

        # how well can it label as real?
        real_loss = self.adversarial_loss(self.discriminator(imgs), y_true)

        # how well can it label as fake?
        fake = torch.ones(imgs.size(0), dtype=torch.long)*10
        fake = fake.type_as(y_true)

        fake_loss = self.adversarial_loss(self.discriminator(self(z,c).detach()), fake)

        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []
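Both losses in `training_step` are the same 11-way cross-entropy: the discriminator is pushed to label real images with their digit and generated images with class 10, while the generator is pushed to make its samples score as the conditioning class `c`. A NumPy sketch with made-up logits shows why a confidently "digit-like" fake is good for the generator and bad for the discriminator:

```python
import numpy as np

def cross_entropy(logits, target):
    # Single-sample version of F.cross_entropy (log-softmax + negative log-likelihood)
    shifted = logits - logits.max()
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return -log_probs[target]

logits = np.zeros(11)
logits[3] = 4.0  # the discriminator thinks this generated sample is a "3"

d_fake_loss = cross_entropy(logits, 10)  # discriminator wants "fake" (class 10)
g_loss = cross_entropy(logits, 3)        # generator wants the conditioning class c=3
```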
In [14]:
model = ACGAN(128)
In [15]:
# model = ACGAN.load_from_checkpoint("logs/ac-gan/mnist/checkpoints/epoch=9-step=9380.ckpt")

Before Training...¶

In [16]:
logger = pl.loggers.CSVLogger("logs",
                              name="ac-gan",
                              version="mnist")
In [17]:
trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=10,
    logger=logger,
    callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=50)]
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
In [18]:
imgs = trainer.predict(model, data_module)[0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
/opt/conda/lib/python3.12/site-packages/torch/utils/data/dataloader.py:617: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(
Predicting: |          | 0/? [00:00<?, ?it/s]
In [19]:
imgs.shape
Out[19]:
torch.Size([80, 28, 28, 1])
In [20]:
fig = plt.figure(figsize=(10, 8))
for i in range(imgs.shape[0]):
    plt.subplot(10, 8, i+1)
    plt.imshow(imgs[i], cmap="gray")
    plt.axis("off")
plt.show()
[Figure: 10x8 grid of generated samples from the untrained model]
In [ ]:
trainer.fit(model, data_module)
In [20]:
results = pd.read_csv("logs/ac-gan/mnist/metrics.csv",sep=",")
In [21]:
plt.plot(results["epoch"][np.logical_not(np.isnan(results["g_loss"]))],
         results["g_loss"][np.logical_not(np.isnan(results["g_loss"]))],
         label="Loss")
plt.ylabel('Generator Loss')
plt.xlabel('Epoch')
# plt.legend(loc='lower right')
plt.show()
[Figure: generator loss vs. epoch]
In [22]:
plt.plot(results["epoch"][np.logical_not(np.isnan(results["d_loss"]))],
         results["d_loss"][np.logical_not(np.isnan(results["d_loss"]))],
         label="Loss")
plt.ylabel('Discriminator Loss')
plt.xlabel('Epoch')
# plt.legend(loc='lower right')
plt.show()
[Figure: discriminator loss vs. epoch]

After Training...¶

In [144]:
imgs = trainer.predict(model, data_module)[0]
Predicting: |          | 0/? [00:00<?, ?it/s]
In [145]:
imgs.shape
Out[145]:
torch.Size([80, 28, 28, 1])
In [32]:
fig = plt.figure(figsize=(10, 8))
for i in range(imgs.shape[0]):
    plt.subplot(10, 8, i+1)
    plt.imshow(imgs[i], cmap="gray")
    plt.axis("off")
plt.show()
[Figure: 10x8 grid of generated samples after training]

DDPM¶

Denoising Diffusion Probabilistic Model¶

Adapted from: https://github.com/tcapelle/Diffusion-Models-pytorch/

The model below implements a Denoising Diffusion Probabilistic Model (DDPM) for generating MNIST digits. Diffusion training is not significantly prone to mode collapse: the relatively large latent size and the residually-connected U-Net help, but the denoising training regime itself is the main reason. Much of the training and prediction code was inspired by the repository linked above, with several changes: everything was ported to PyTorch Lightning, the noise-prediction architecture is based more on the GAN code above (following U-Net construction principles), a more standardized positional encoding provides the time-conditioning input, and mean log-cosh error (MLCE) loss replaces mean squared error (MSE) for slightly more stable training.
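Log-cosh behaves like (half) the squared error near zero but grows only linearly for large residuals, which damps the gradient contribution of outlier noise predictions. A quick NumPy check of both regimes:

```python
import numpy as np

def log_cosh(err):
    # log(cosh(x)) computed stably as |x| + log1p(exp(-2|x|)) - log(2)
    a = np.abs(err)
    return a + np.log1p(np.exp(-2.0 * a)) - np.log(2.0)

# Near zero, log-cosh is quadratic, approximately err**2 / 2 ...
near = log_cosh(0.01)   # ~ 0.01**2 / 2
# ... far from zero it is linear (~ |err| - log 2), versus MSE's err**2
far = log_cosh(10.0)    # ~ 10 - log(2), while the squared error would be 100
```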

In [5]:
class SinePositionEncoding(pl.LightningModule):
    def __init__(self,
                 dim,
                 max_wavelength=10000.0,
                 **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_wavelength = torch.Tensor([max_wavelength])

    def forward(self, x):
        hidden_size = self.dim
        position = x
        min_freq = (1 / self.max_wavelength).type_as(x)
        timescales = torch.pow(
            min_freq,
            (2 * (torch.arange(hidden_size) // 2)).type_as(x)
            / torch.Tensor([hidden_size]).type_as(x)
        )
        angles = torch.unsqueeze(position, 1) * torch.unsqueeze(timescales, 0)
        cos_mask = (torch.arange(hidden_size) % 2).type_as(x)
        sin_mask = 1 - cos_mask
        positional_encodings = (
            torch.sin(angles) * sin_mask + torch.cos(angles) * cos_mask
        )
        return positional_encodings.type_as(x)
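The module above interleaves sines and cosines: even feature indices carry the sine of an angle, odd indices the cosine of the same angle, with each channel pair sharing one timescale. A compact NumPy version of the same formula (a sketch, not the module itself):

```python
import numpy as np

def sine_position_encoding(positions, dim, max_wavelength=10000.0):
    """positions: (N,) array of timesteps -> (N, dim) sinusoidal encodings."""
    min_freq = 1.0 / max_wavelength
    # Each channel pair (2i, 2i+1) shares one timescale min_freq**(2i/dim)
    timescales = min_freq ** (2 * (np.arange(dim) // 2) / dim)
    angles = positions[:, None] * timescales[None, :]
    # Even channels get sin(angle), odd channels get cos(angle)
    return np.where(np.arange(dim) % 2 == 0, np.sin(angles), np.cos(angles))

enc = sine_position_encoding(np.array([0.0, 1.0, 2.0]), dim=8)
```

At position 0 this yields 0 on every sine channel and 1 on every cosine channel, and all values stay in [-1, 1].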
In [6]:
class TimeConditionedResidualConv2dLayer(pl.LightningModule):
    def __init__(self,
                 size,
                 t_dim = 256,
                 **kwargs):
        super().__init__(**kwargs)
        self.padding = torch.nn.ZeroPad2d((1,1,1,1))
        self.conv_layer = torch.nn.Conv2d(size,
                                          size,
                                          kernel_size=(3,3),
                                          stride=(1,1))
        self.norm = torch.nn.BatchNorm2d(size)
        self.embed_t = torch.nn.Linear(t_dim,size)
        # self.embed_t = torch.nn.Sequential(*[torch.nn.SiLU(),torch.nn.Linear(t_dim,size)])
        self.activation = torch.nn.LeakyReLU(0.2, inplace=True)
    def forward(self, x):
        x, t = x
        y = x
        y = self.padding(y)
        y = self.conv_layer(y)
        y = self.norm(y)
        y = self.activation(y)
        y = y + x + self.embed_t(t)[:,:,None,None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return y, t
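The `[:, :, None, None].repeat(...)` step takes the per-sample time embedding of shape `(B, C)` and tiles it to `(B, C, H, W)` so it can be added to the feature map as a per-channel bias. In NumPy, broadcasting does the same thing without the explicit repeat:

```python
import numpy as np

B, C, H, W = 2, 4, 3, 3
feat = np.zeros((B, C, H, W))
t_emb = np.arange(B * C, dtype=float).reshape(B, C)  # per-sample channel bias

# (B, C) -> (B, C, 1, 1); broadcasting tiles it across H and W during addition
out = feat + t_emb[:, :, None, None]
```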
In [7]:
class TimeConditionedUNet(pl.LightningModule):
    def __init__(self,
                 latent_dim,
                 t_dim = 256,
                 n_channels = 1,
                 depth = 5,
                 **kwargs):
        super().__init__(**kwargs)

        self.t_dim = t_dim
        self.time_encoding = SinePositionEncoding(t_dim)
        self.encoder0 = torch.nn.Conv2d(n_channels,latent_dim,1)
        self.encoder1 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(depth)]) # 28
        self.encoder2 = torch.nn.AvgPool2d((2,2))
        self.encoder3 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(depth)]) # 14
        self.encoder4 = torch.nn.AvgPool2d((2,2))
        self.encoder5 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(depth)]) # 7
        self.encoder6 = torch.nn.ZeroPad2d((1,1,1,1))
        self.encoder7 = torch.nn.AvgPool2d((2,2))
        self.encoder8 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(depth)]) # 4
        self.encoder9 = torch.nn.AvgPool2d((2,2))
        self.encoder10 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(depth)]) # 2
        self.encoder11 = torch.nn.AvgPool2d((2,2))
        self.encoder12 = torch.nn.Flatten()

        self.decoder0 = torch.nn.Linear(latent_dim,latent_dim*4)
        self.decoder1 = torch.nn.Unflatten(1,(latent_dim,2,2))
        self.decoder2 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(2)])
        self.decoder3 = torch.nn.Upsample(scale_factor=2)
        self.decoder4 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(2)])
        self.decoder5 = torch.nn.Upsample(scale_factor=2)
        self.decoder6 = torch.nn.Conv2d(latent_dim,latent_dim,(2,2))
        self.decoder7 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim=t_dim) for _ in range(2)])
        self.decoder8 = torch.nn.Upsample(scale_factor=2)
        self.decoder9 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim) for _ in range(2)])
        self.decoder10 = torch.nn.Upsample(scale_factor=2)
        self.decoder11 = torch.nn.Sequential(*[TimeConditionedResidualConv2dLayer(latent_dim, t_dim) for _ in range(2)])
        self.decoder12 = torch.nn.Conv2d(latent_dim,n_channels,1)

    def time_embedding(self, t):
        # Unused alternative to SinePositionEncoding (see the commented call in forward)
        channels = self.t_dim
        inv_freq = (1.0 / (
            10000
            ** (torch.arange(0, channels, 2).float() / channels)
        )).type_as(t)
        time_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        time_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        time_enc = torch.cat([time_enc_a, time_enc_b], dim=-1)
        return time_enc
        
    def forward(self, x, t):
        t = self.time_encoding(t.type_as(x))
        # t = self.time_embedding(t.unsqueeze(-1).type_as(x))
        y = y0 = x
        y = self.encoder0(y)
        y = y1 = self.encoder1((y,t))[0]
        y = self.encoder2(y)
        y = y3 = self.encoder3((y,t))[0]
        y = self.encoder4(y)
        y = y5 = self.encoder5((y,t))[0]
        y = self.encoder6(y)
        y = self.encoder7(y)
        y = y8 = self.encoder8((y,t))[0]
        y = self.encoder9(y)
        y = y10 = self.encoder10((y,t))[0]
        y = self.encoder11(y)
        y = self.encoder12(y)

        y = self.decoder0(y)
        y = self.decoder1(y)
        y = self.decoder2((y,t))[0] + y10
        y = self.decoder3(y)
        y = self.decoder4((y,t))[0] + y8
        y = self.decoder5(y)
        y = self.decoder6(y)
        y = self.decoder7((y,t))[0] + y5
        y = self.decoder8(y)
        y = self.decoder9((y,t))[0] + y3
        y = self.decoder10(y)
        y = self.decoder11((y,t))[0] + y1
        y = self.decoder12(y) + y0

        return y
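The encoder halves the spatial resolution at each stage, zero-padding 7 up to 9 once so the pooling divides evenly (28 → 14 → 7 → 4 → 2 → 1); on the way up, each `Upsample` doubles it, with the `(2,2)` convolution trimming 8 back down to 7 so the skip connections line up. The downsampling schedule can be checked with a few lines:

```python
size = 28
sizes = [size]
for pad in (0, 0, 1, 0, 0):       # only the 7 -> 4 stage zero-pads (to 9) first
    size = (size + 2 * pad) // 2  # each AvgPool2d((2, 2)) halves the resolution
    sizes.append(size)
```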
In [8]:
temp_u = TimeConditionedUNet(128)
summary(temp_u,input_size=[(data_module.batch_size,)+data_module.input_shape,
                           (data_module.batch_size,)])
Out[8]:
=========================================================================================================
Layer (type:depth-idx)                                  Output Shape              Param #
=========================================================================================================
TimeConditionedUNet                                     [128, 1, 28, 28]          --
├─SinePositionEncoding: 1-1                             [128, 256]                --
├─Conv2d: 1-2                                           [128, 128, 28, 28]        256
├─Sequential: 1-3                                       [128, 128, 28, 28]        --
│    └─TimeConditionedResidualConv2dLayer: 2-1          [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-1                              [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-2                                 [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-3                            [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-4                              [128, 128, 28, 28]        --
│    │    └─Linear: 3-5                                 [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-2          [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-6                              [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-7                                 [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-8                            [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-9                              [128, 128, 28, 28]        --
│    │    └─Linear: 3-10                                [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-3          [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-11                             [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-12                                [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-13                           [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-14                             [128, 128, 28, 28]        --
│    │    └─Linear: 3-15                                [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-4          [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-16                             [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-17                                [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-18                           [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-19                             [128, 128, 28, 28]        --
│    │    └─Linear: 3-20                                [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-5          [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-21                             [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-22                                [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-23                           [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-24                             [128, 128, 28, 28]        --
│    │    └─Linear: 3-25                                [128, 128]                32,896
├─AvgPool2d: 1-4                                        [128, 128, 14, 14]        --
├─Sequential: 1-5                                       [128, 128, 14, 14]        --
│    └─TimeConditionedResidualConv2dLayer: 2-6          [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-26                             [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-27                                [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-28                           [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-29                             [128, 128, 14, 14]        --
│    │    └─Linear: 3-30                                [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-7          [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-31                             [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-32                                [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-33                           [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-34                             [128, 128, 14, 14]        --
│    │    └─Linear: 3-35                                [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-8          [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-36                             [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-37                                [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-38                           [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-39                             [128, 128, 14, 14]        --
│    │    └─Linear: 3-40                                [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-9          [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-41                             [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-42                                [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-43                           [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-44                             [128, 128, 14, 14]        --
│    │    └─Linear: 3-45                                [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-10         [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-46                             [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-47                                [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-48                           [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-49                             [128, 128, 14, 14]        --
│    │    └─Linear: 3-50                                [128, 128]                32,896
├─AvgPool2d: 1-6                                        [128, 128, 7, 7]          --
├─Sequential: 1-7                                       [128, 128, 7, 7]          --
│    └─TimeConditionedResidualConv2dLayer: 2-11         [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-51                             [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-52                                [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-53                           [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-54                             [128, 128, 7, 7]          --
│    │    └─Linear: 3-55                                [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-12         [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-56                             [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-57                                [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-58                           [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-59                             [128, 128, 7, 7]          --
│    │    └─Linear: 3-60                                [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-13         [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-61                             [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-62                                [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-63                           [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-64                             [128, 128, 7, 7]          --
│    │    └─Linear: 3-65                                [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-14         [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-66                             [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-67                                [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-68                           [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-69                             [128, 128, 7, 7]          --
│    │    └─Linear: 3-70                                [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-15         [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-71                             [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-72                                [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-73                           [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-74                             [128, 128, 7, 7]          --
│    │    └─Linear: 3-75                                [128, 128]                32,896
├─ZeroPad2d: 1-8                                        [128, 128, 9, 9]          --
├─AvgPool2d: 1-9                                        [128, 128, 4, 4]          --
├─Sequential: 1-10                                      [128, 128, 4, 4]          --
│    └─TimeConditionedResidualConv2dLayer: 2-16         [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-76                             [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-77                                [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-78                           [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-79                             [128, 128, 4, 4]          --
│    │    └─Linear: 3-80                                [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-17         [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-81                             [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-82                                [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-83                           [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-84                             [128, 128, 4, 4]          --
│    │    └─Linear: 3-85                                [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-18         [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-86                             [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-87                                [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-88                           [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-89                             [128, 128, 4, 4]          --
│    │    └─Linear: 3-90                                [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-19         [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-91                             [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-92                                [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-93                           [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-94                             [128, 128, 4, 4]          --
│    │    └─Linear: 3-95                                [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-20         [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-96                             [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-97                                [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-98                           [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-99                             [128, 128, 4, 4]          --
│    │    └─Linear: 3-100                               [128, 128]                32,896
├─AvgPool2d: 1-11                                       [128, 128, 2, 2]          --
├─Sequential: 1-12                                      [128, 128, 2, 2]          --
│    └─TimeConditionedResidualConv2dLayer: 2-21         [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-101                            [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-102                               [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-103                          [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-104                            [128, 128, 2, 2]          --
│    │    └─Linear: 3-105                               [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-22         [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-106                            [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-107                               [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-108                          [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-109                            [128, 128, 2, 2]          --
│    │    └─Linear: 3-110                               [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-23         [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-111                            [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-112                               [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-113                          [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-114                            [128, 128, 2, 2]          --
│    │    └─Linear: 3-115                               [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-24         [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-116                            [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-117                               [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-118                          [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-119                            [128, 128, 2, 2]          --
│    │    └─Linear: 3-120                               [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-25         [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-121                            [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-122                               [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-123                          [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-124                            [128, 128, 2, 2]          --
│    │    └─Linear: 3-125                               [128, 128]                32,896
├─AvgPool2d: 1-13                                       [128, 128, 1, 1]          --
├─Flatten: 1-14                                         [128, 128]                --
├─Linear: 1-15                                          [128, 512]                66,048
├─Unflatten: 1-16                                       [128, 128, 2, 2]          --
├─Sequential: 1-17                                      [128, 128, 2, 2]          --
│    └─TimeConditionedResidualConv2dLayer: 2-26         [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-126                            [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-127                               [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-128                          [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-129                            [128, 128, 2, 2]          --
│    │    └─Linear: 3-130                               [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-27         [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-131                            [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-132                               [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-133                          [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-134                            [128, 128, 2, 2]          --
│    │    └─Linear: 3-135                               [128, 128]                32,896
├─Upsample: 1-18                                        [128, 128, 4, 4]          --
├─Sequential: 1-19                                      [128, 128, 4, 4]          --
│    └─TimeConditionedResidualConv2dLayer: 2-28         [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-136                            [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-137                               [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-138                          [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-139                            [128, 128, 4, 4]          --
│    │    └─Linear: 3-140                               [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-29         [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-141                            [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-142                               [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-143                          [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-144                            [128, 128, 4, 4]          --
│    │    └─Linear: 3-145                               [128, 128]                32,896
├─Upsample: 1-20                                        [128, 128, 8, 8]          --
├─Conv2d: 1-21                                          [128, 128, 7, 7]          65,664
├─Sequential: 1-22                                      [128, 128, 7, 7]          --
│    └─TimeConditionedResidualConv2dLayer: 2-30         [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-146                            [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-147                               [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-148                          [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-149                            [128, 128, 7, 7]          --
│    │    └─Linear: 3-150                               [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-31         [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-151                            [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-152                               [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-153                          [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-154                            [128, 128, 7, 7]          --
│    │    └─Linear: 3-155                               [128, 128]                32,896
├─Upsample: 1-23                                        [128, 128, 14, 14]        --
├─Sequential: 1-24                                      [128, 128, 14, 14]        --
│    └─TimeConditionedResidualConv2dLayer: 2-32         [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-156                            [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-157                               [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-158                          [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-159                            [128, 128, 14, 14]        --
│    │    └─Linear: 3-160                               [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-33         [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-161                            [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-162                               [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-163                          [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-164                            [128, 128, 14, 14]        --
│    │    └─Linear: 3-165                               [128, 128]                32,896
├─Upsample: 1-25                                        [128, 128, 28, 28]        --
├─Sequential: 1-26                                      [128, 128, 28, 28]        --
│    └─TimeConditionedResidualConv2dLayer: 2-34         [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-166                            [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-167                               [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-168                          [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-169                            [128, 128, 28, 28]        --
│    │    └─Linear: 3-170                               [128, 128]                32,896
│    └─TimeConditionedResidualConv2dLayer: 2-35         [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-171                            [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-172                               [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-173                          [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-174                            [128, 128, 28, 28]        --
│    │    └─Linear: 3-175                               [128, 128]                32,896
├─Conv2d: 1-27                                          [128, 1, 28, 28]          129
=========================================================================================================
Total params: 6,457,857
Trainable params: 6,457,857
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 139.32
=========================================================================================================
Input size (MB): 0.40
Forward/backward pass size (MB): 2040.02
Params size (MB): 25.83
Estimated Total Size (MB): 2066.25
=========================================================================================================
In [9]:
# This kind of U-Net is large and difficult to view.
# The PNG export is too fuzzy as well (it was already
# borderline for the GAN components above). Running the
# code below will produce a clearer visualization.

# model_graph = draw_graph(temp_u,
#                          input_size=[(data_module.batch_size,)+data_module.input_shape,
#                                      (data_module.batch_size,)],
#                          depth=3)
# model_graph.visual_graph


In [10]:
class DiffusionModel(pl.LightningModule):
    def __init__(
        self,
        latent_dim: int = 128,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = 256,
        noise_steps=1000,
        beta_start=1e-4,
        beta_end=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.save_hyperparameters()

        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.beta = self.prepare_noise_schedule()
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)
        
        # networks
        self.unet = TimeConditionedUNet(latent_dim=self.hparams.latent_dim)
        # loss
        # self.my_loss = torch.nn.MSELoss()
        self.my_loss = torchmetrics.LogCoshError(num_outputs=np.prod([1,28,28]))
        # ema
        # self.ema = torch_ema.ExponentialMovingAverage(self.unet.parameters(), decay=0.995)
        
    def forward(self, x, t):
        return self.unet(x, t)

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2
        return torch.optim.AdamW(self.unet.parameters(), lr=lr, betas=(b1, b2))

    def predict_step(self, batch):
        x, y_true = batch
        data = []
        x_t = torch.randn_like(x).type_as(x)
        data += [x_t]
        for i in reversed(range(1,self.noise_steps)):
            t = (torch.ones(x.shape[0])*i).type_as(y_true)
            # with self.ema.average_parameters():
            predicted_noise = self(x_t, t.type_as(x_t))
            alpha = self.alpha.type_as(x)[t][:, None, None, None]
            alpha_hat = self.alpha_hat.type_as(x)[t][:, None, None, None]
            beta = self.beta.type_as(x)[t][:, None, None, None]
            if i > 1:
                noise = torch.randn_like(x).type_as(x)
            else:
                noise = torch.zeros_like(x).type_as(x)
            x_t = 1 / torch.sqrt(alpha) * (x_t - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
            data += [x_t]
        return (torch.clip(torch.permute(torch.stack(data),(0,1,3,4,2)),-1,1)+1.0)/2.0
        
    def training_step(self, batch):
        x, y_true = batch
        t = self.sample_timesteps(x.shape[0]).type_as(y_true)
        x_t, noise = self.noise_images(x,t)
        predicted_noise = self(x_t, t.type_as(x_t))
        # loss = self.my_loss(predicted_noise,
        #                     noise)
        loss = torch.mean(self.my_loss(torch.flatten(predicted_noise,start_dim=1),
                                       torch.flatten(noise,start_dim=1)))
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        return loss

    # def on_before_zero_grad(self, *args, **kwargs):
    #     self.ema.update(self.unet.parameters())

    def prepare_noise_schedule(self):
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)
        
    def sample_timesteps(self, n):
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def noise_images(self, x, t):
        alpha_hat = self.alpha_hat.type_as(x)
        sqrt_alpha_hat = torch.sqrt(alpha_hat[t])[:, None, None, None].type_as(x)
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - alpha_hat[t])[:, None, None, None].type_as(x)
        e = torch.randn_like(x).type_as(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * e, e
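The forward (noising) process in noise_images above can be checked in isolation. This is a minimal sketch using the same linear beta schedule (assumed defaults: 1000 steps, beta from 1e-4 to 0.02); it verifies that the cumulative signal scale alpha_hat decays monotonically and that the signal and noise scales are complementary at every step.

```python
import torch

# Rebuild the linear noise schedule used by DiffusionModel (assumed defaults).
noise_steps, beta_start, beta_end = 1000, 1e-4, 0.02
beta = torch.linspace(beta_start, beta_end, noise_steps)
alpha_hat = torch.cumprod(1.0 - beta, dim=0)

# alpha_hat strictly decreases: the signal fades as t grows.
assert bool((alpha_hat[1:] < alpha_hat[:-1]).all())

# Signal/noise scales are complementary: a_t^2 + b_t^2 = 1 for all t.
a = torch.sqrt(alpha_hat)
b = torch.sqrt(1.0 - alpha_hat)
assert torch.allclose(a**2 + b**2, torch.ones_like(a))

# Apply one noising step exactly as noise_images does.
x0 = torch.randn(4, 1, 28, 28)
t = torch.randint(low=1, high=noise_steps, size=(4,))
e = torch.randn_like(x0)
x_t = a[t][:, None, None, None] * x0 + b[t][:, None, None, None] * e
assert x_t.shape == x0.shape
```

Because a_t^2 + b_t^2 = 1, a unit-variance input stays (approximately) unit-variance at every timestep, which is what lets the reverse process start from pure standard-normal noise.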
In [11]:
model = DiffusionModel()
In [12]:
# model = DiffusionModel.load_from_checkpoint("logs/diffusion/mnist/checkpoints/epoch=999-step=469000.ckpt")
In [13]:
logger = pl.loggers.CSVLogger("logs",
                              name="diffusion",
                              version="mnist")
In [14]:
trainer = pl.Trainer(logger=logger,
                     max_epochs=1000,
                     enable_progress_bar=True,
                     log_every_n_steps=0,
                     callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=50)]
                    )
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

Before Training...¶

In [15]:
result = trainer.predict(model,data_module)[0]
all_imgs = result.cpu().detach().numpy()
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
/opt/conda/lib/python3.12/site-packages/torch/utils/data/dataloader.py:617: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(
Predicting: |          | 0/? [00:00<?, ?it/s]
In [16]:
all_imgs.shape
Out[16]:
(1000, 80, 28, 28, 1)
In [17]:
imgs = all_imgs[-1]
In [18]:
fig = plt.figure(figsize=(10, 8))
for i in range(imgs.shape[0]):
    plt.subplot(10, 8, i+1)
    plt.imshow(imgs[i], cmap="gray")
    plt.axis("off")
plt.show()
(Figure: 10x8 grid of generated samples from the untrained model; essentially pure noise.)
In [114]:
# Training Step
trainer.fit(model, data_module)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]

  | Name    | Type                | Params
------------------------------------------------
0 | unet    | TimeConditionedUNet | 6.5 M 
1 | my_loss | LogCoshError        | 0     
------------------------------------------------
6.5 M     Trainable params
0         Non-trainable params
6.5 M     Total params
25.831    Total estimated model params size (MB)
Training: |          | 0/? [00:00<?, ?it/s]
/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
In [195]:
results = pd.read_csv("logs/diffusion/mnist/metrics.csv",sep=",")
In [196]:
plt.plot(results["epoch"][np.logical_not(np.isnan(results["train_loss"]))],
         results["train_loss"][np.logical_not(np.isnan(results["train_loss"]))],
         label="Loss")
plt.ylabel('MeanLogCosh Loss')
plt.xlabel('Epoch')
plt.ylim([0.0,0.1])
# plt.legend(loc='lower right')
plt.show()
(Figure: MeanLogCosh training loss vs. epoch.)

After Training...¶

In [53]:
result = trainer.predict(model,data_module)[0]
all_imgs = result.cpu().detach().numpy()
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
Predicting: |          | 0/? [00:00<?, ?it/s]
In [54]:
all_imgs.shape
Out[54]:
(1000, 80, 28, 28, 1)
In [55]:
imgs = all_imgs[-1]
In [56]:
fig = plt.figure(figsize=(10, 8))
for i in range(imgs.shape[0]):
    plt.subplot(10, 8, i+1)
    plt.imshow(imgs[i], cmap="gray")
    plt.axis("off")
plt.show()
(Figure: 10x8 grid of generated MNIST digit samples after training.)

CDDPM¶

Conditioned Denoising Diffusion Probabilistic Model¶

Adapted from: https://github.com/tcapelle/Diffusion-Models-pytorch/

The model below implements a Conditioned Denoising Diffusion Probabilistic Model (CDDPM) that learns to generate requested MNIST digits. Given the nature of diffusion training, the model is not significantly prone to mode collapse, partly due to the relatively large latent space and the residually connected U-Net convolutional network, but mainly due to the denoising training regime of DDPMs. Much of the training and prediction code was inspired by the repository linked above, but everything was ported to PyTorch Lightning, the noise-prediction architecture is based more on the GAN code above (following the principles of U-Net construction), a more standardized form of positional encoding provides the time-conditioning input, and MeanLogCoshError (MLCE) loss is used instead of MeanSquaredError (MSE) loss for slightly more stable training in general.

Key difference from the DDPM model above: an additional integer input specifies the digit being requested. It is embedded with a torch.nn.Embedding layer, and the resulting embedding is added to the time encodings.

Other changes: I rushed the training on this one a bit by reducing the number of diffusion steps (from 1000 to 200), the size of the time embeddings (from 256 to 128), and the number of training epochs (from 1000 to 500). While I partially compensated by adding more residual layers to each block (from 2 to 5), this wasn't enough to overcome the limitations of the shortened training regime. The results are therefore a little muddier than I would like, but they show that the model works properly; with some additional hyperparameter adjustments (1000 steps, 256-dimensional time and condition embeddings) and the patience to train for about 1000 epochs, everything should sharpen up.
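The conditioning mechanism described above can be sketched on its own: the digit label is embedded and summed with the time encoding, producing a single vector that conditions every residual block. This is a minimal illustration with assumed shapes (c_dim = 256, batch of 2); the random tensor stands in for the sinusoidal time encoding produced by the class defined below.

```python
import torch

c_dim = 256
class_embedding = torch.nn.Embedding(10, c_dim)  # one row per digit class

labels = torch.tensor([3, 7])            # requested digits
t_enc = torch.randn(2, c_dim)            # stand-in for the sinusoidal time encoding
cond = class_embedding(labels) + t_enc   # one conditioning vector per sample
assert cond.shape == (2, c_dim)
```

Summing (rather than concatenating) keeps the conditioning vector the same size as the time encoding, so the downstream per-block Linear layers need no changes relative to the unconditioned DDPM.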

In [5]:
class SinePositionEncoding(pl.LightningModule):
    def __init__(self,
                 dim,
                 max_wavelength=10000.0,
                 **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_wavelength = torch.Tensor([max_wavelength])

    def forward(self, x):
        hidden_size = self.dim
        position = x
        min_freq = (1 / self.max_wavelength).type_as(x)
        timescales = torch.pow(
            min_freq,
            (2 * (torch.arange(hidden_size) // 2)).type_as(x)
            / torch.Tensor([hidden_size]).type_as(x)
        )
        angles = torch.unsqueeze(position, 1) * torch.unsqueeze(timescales, 0)
        cos_mask = (torch.arange(hidden_size) % 2).type_as(x)
        sin_mask = 1 - cos_mask
        positional_encodings = (
            torch.sin(angles) * sin_mask + torch.cos(angles) * cos_mask
        )
        return positional_encodings.type_as(x)
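A quick standalone check of the encoding scheme above (using assumed small dim = 8 for readability): even indices carry sine terms and odd indices carry cosine terms, so t = 0 should encode to [0, 1, 0, 1, ...].

```python
import torch

# Recompute the sinusoidal encoding by hand for two timesteps.
dim, max_wavelength = 8, 10000.0
t = torch.tensor([0.0, 1.0])
timescales = torch.pow(
    torch.tensor(1.0 / max_wavelength),
    (2 * (torch.arange(dim) // 2)).float() / dim,
)
angles = t[:, None] * timescales[None, :]
cos_mask = (torch.arange(dim) % 2).float()  # 1 at odd indices
sin_mask = 1.0 - cos_mask                   # 1 at even indices
enc = torch.sin(angles) * sin_mask + torch.cos(angles) * cos_mask

assert enc.shape == (2, dim)
assert torch.allclose(enc[0], cos_mask)  # t=0: sin(0)=0 at even, cos(0)=1 at odd
```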
In [6]:
class ConditionedResidualConv2dLayer(pl.LightningModule):
    def __init__(self,
                 size,
                 c_dim = 256,
                 **kwargs):
        super().__init__(**kwargs)
        self.padding = torch.nn.ZeroPad2d((1,1,1,1))
        self.conv_layer = torch.nn.Conv2d(size,
                                          size,
                                          kernel_size=(3,3),
                                          stride=(1,1))
        self.norm = torch.nn.BatchNorm2d(size)
        self.embed_t = torch.nn.Linear(c_dim,size)
        self.activation = torch.nn.LeakyReLU(0.2, inplace=True)
    def forward(self, x):
        x, c = x
        y = x
        y = self.padding(y)
        y = self.conv_layer(y)
        y = self.norm(y)
        y = self.activation(y)
        y = y + x + self.embed_t(c)[:,:,None,None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return y, c
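The residual add in the layer above only type-checks because ZeroPad2d(1) followed by a 3x3 stride-1 convolution preserves spatial size. A self-contained smoke test of that arithmetic (assumed channel count 16 and 14x14 input):

```python
import torch

size = 16
layer = torch.nn.Sequential(
    torch.nn.ZeroPad2d((1, 1, 1, 1)),                         # 14x14 -> 16x16
    torch.nn.Conv2d(size, size, kernel_size=(3, 3), stride=(1, 1)),  # 16x16 -> 14x14
)
x = torch.randn(2, size, 14, 14)
y = layer(x)
assert y.shape == x.shape  # so y + x in the residual connection is valid
```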
In [7]:
class ConditionedUNet(pl.LightningModule):
    def __init__(self,
                 latent_dim,
                 c_dim = 256,
                 n_channels = 1,
                 depth = 5):
        super().__init__()

        self.time_encoding = SinePositionEncoding(c_dim)
        self.class_embedding = torch.nn.Embedding(10,c_dim)
        self.encoder0 = torch.nn.Conv2d(n_channels,latent_dim,1)
        self.encoder1 = torch.nn.Sequential(*[ConditionedResidualConv2dLayer(latent_dim, c_dim=c_dim) for _ in range(depth)]) # 28
        self.encoder2 = torch.nn.AvgPool2d((2,2))
        self.encoder3 = torch.nn.Sequential(*[ConditionedResidualConv2dLayer(latent_dim, c_dim=c_dim) for _ in range(depth)]) # 14
        self.encoder4 = torch.nn.AvgPool2d((2,2))
        self.encoder5 = torch.nn.Sequential(*[ConditionedResidualConv2dLayer(latent_dim, c_dim=c_dim) for _ in range(depth)]) # 7
        self.encoder6 = torch.nn.ZeroPad2d((1,1,1,1))
        self.encoder7 = torch.nn.AvgPool2d((2,2))
        self.encoder8 = torch.nn.Sequential(*[ConditionedResidualConv2dLayer(latent_dim, c_dim=c_dim) for _ in range(depth)]) # 4
        self.encoder9 = torch.nn.AvgPool2d((2,2))
        self.encoder10 = torch.nn.Sequential(*[ConditionedResidualConv2dLayer(latent_dim, c_dim=c_dim) for _ in range(depth)]) # 2
        self.encoder11 = torch.nn.AvgPool2d((2,2))
        self.encoder12 = torch.nn.Flatten()

        self.decoder0 = torch.nn.Linear(latent_dim,latent_dim*4)
        self.decoder1 = torch.nn.Unflatten(1,(latent_dim,2,2))
        self.decoder2 = torch.nn.Sequential(*[ConditionedResidualConv2dLayer(latent_dim, c_dim=c_dim) for _ in range(depth)])
        self.decoder3 = torch.nn.Upsample(scale_factor=2)
        self.decoder4 = torch.nn.Sequential(*[ConditionedResidualConv2dLayer(latent_dim, c_dim=c_dim) for _ in range(depth)])
        self.decoder5 = torch.nn.Upsample(scale_factor=2)
        self.decoder6 = torch.nn.Conv2d(latent_dim,latent_dim,(2,2))
        self.decoder7 = torch.nn.Sequential(*[ConditionedResidualConv2dLayer(latent_dim, c_dim=c_dim) for _ in range(depth)])
        self.decoder8 = torch.nn.Upsample(scale_factor=2)
        self.decoder9 = torch.nn.Sequential(*[ConditionedResidualConv2dLayer(latent_dim, c_dim=c_dim) for _ in range(depth)])
        self.decoder10 = torch.nn.Upsample(scale_factor=2)
        self.decoder11 = torch.nn.Sequential(*[ConditionedResidualConv2dLayer(latent_dim, c_dim=c_dim) for _ in range(depth)])
        self.decoder12 = torch.nn.Conv2d(latent_dim,n_channels,1)
        
    def forward(self, x, c, t):
        cond = torch.squeeze(self.class_embedding(torch.unsqueeze(c,-1))) + \
               self.time_encoding(t.type_as(x))
        y = y0 = x
        y = self.encoder0(y)
        y = y1 = self.encoder1((y,cond))[0]
        y = self.encoder2(y)
        y = y3 = self.encoder3((y,cond))[0]
        y = self.encoder4(y)
        y = y5 = self.encoder5((y,cond))[0]
        y = self.encoder6(y)
        y = self.encoder7(y)
        y = y8 = self.encoder8((y,cond))[0]
        y = self.encoder9(y)
        y = y10 = self.encoder10((y,cond))[0]
        y = self.encoder11(y)
        y = self.encoder12(y)

        y = self.decoder0(y)
        y = self.decoder1(y)
        y = self.decoder2((y,cond))[0] + y10
        y = self.decoder3(y)
        y = self.decoder4((y,cond))[0] + y8
        y = self.decoder5(y)
        y = self.decoder6(y)
        y = self.decoder7((y,cond))[0] + y5
        y = self.decoder8(y)
        y = self.decoder9((y,cond))[0] + y3
        y = self.decoder10(y)
        y = self.decoder11((y,cond))[0] + y1
        y = self.decoder12(y) + y0

        return y
In [10]:
temp_u = ConditionedUNet(128)
summary(temp_u,input_size=[(data_module.batch_size,)+data_module.input_shape,
                           (data_module.batch_size,),
                           (data_module.batch_size,)],
        dtypes=[torch.float32,torch.long,torch.long])
Out[10]:
====================================================================================================
Layer (type:depth-idx)                             Output Shape              Param #
====================================================================================================
ConditionedUNet                                    [128, 1, 28, 28]          --
├─Embedding: 1-1                                   [128, 1, 256]             2,560
├─SinePositionEncoding: 1-2                        [128, 256]                --
├─Conv2d: 1-3                                      [128, 128, 28, 28]        256
├─Sequential: 1-4                                  [128, 128, 28, 28]        --
│    └─ConditionedResidualConv2dLayer: 2-1         [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-1                         [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-2                            [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-3                       [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-4                         [128, 128, 28, 28]        --
│    │    └─Linear: 3-5                            [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-2         [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-6                         [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-7                            [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-8                       [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-9                         [128, 128, 28, 28]        --
│    │    └─Linear: 3-10                           [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-3         [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-11                        [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-12                           [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-13                      [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-14                        [128, 128, 28, 28]        --
│    │    └─Linear: 3-15                           [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-4         [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-16                        [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-17                           [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-18                      [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-19                        [128, 128, 28, 28]        --
│    │    └─Linear: 3-20                           [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-5         [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-21                        [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-22                           [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-23                      [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-24                        [128, 128, 28, 28]        --
│    │    └─Linear: 3-25                           [128, 128]                32,896
├─AvgPool2d: 1-5                                   [128, 128, 14, 14]        --
├─Sequential: 1-6                                  [128, 128, 14, 14]        --
│    └─ConditionedResidualConv2dLayer: 2-6         [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-26                        [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-27                           [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-28                      [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-29                        [128, 128, 14, 14]        --
│    │    └─Linear: 3-30                           [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-7         [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-31                        [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-32                           [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-33                      [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-34                        [128, 128, 14, 14]        --
│    │    └─Linear: 3-35                           [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-8         [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-36                        [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-37                           [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-38                      [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-39                        [128, 128, 14, 14]        --
│    │    └─Linear: 3-40                           [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-9         [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-41                        [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-42                           [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-43                      [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-44                        [128, 128, 14, 14]        --
│    │    └─Linear: 3-45                           [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-10        [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-46                        [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-47                           [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-48                      [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-49                        [128, 128, 14, 14]        --
│    │    └─Linear: 3-50                           [128, 128]                32,896
├─AvgPool2d: 1-7                                   [128, 128, 7, 7]          --
├─Sequential: 1-8                                  [128, 128, 7, 7]          --
│    └─ConditionedResidualConv2dLayer: 2-11        [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-51                        [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-52                           [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-53                      [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-54                        [128, 128, 7, 7]          --
│    │    └─Linear: 3-55                           [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-12        [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-56                        [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-57                           [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-58                      [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-59                        [128, 128, 7, 7]          --
│    │    └─Linear: 3-60                           [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-13        [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-61                        [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-62                           [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-63                      [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-64                        [128, 128, 7, 7]          --
│    │    └─Linear: 3-65                           [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-14        [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-66                        [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-67                           [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-68                      [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-69                        [128, 128, 7, 7]          --
│    │    └─Linear: 3-70                           [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-15        [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-71                        [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-72                           [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-73                      [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-74                        [128, 128, 7, 7]          --
│    │    └─Linear: 3-75                           [128, 128]                32,896
├─ZeroPad2d: 1-9                                   [128, 128, 9, 9]          --
├─AvgPool2d: 1-10                                  [128, 128, 4, 4]          --
├─Sequential: 1-11                                 [128, 128, 4, 4]          --
│    └─ConditionedResidualConv2dLayer: 2-16        [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-76                        [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-77                           [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-78                      [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-79                        [128, 128, 4, 4]          --
│    │    └─Linear: 3-80                           [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-17        [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-81                        [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-82                           [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-83                      [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-84                        [128, 128, 4, 4]          --
│    │    └─Linear: 3-85                           [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-18        [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-86                        [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-87                           [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-88                      [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-89                        [128, 128, 4, 4]          --
│    │    └─Linear: 3-90                           [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-19        [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-91                        [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-92                           [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-93                      [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-94                        [128, 128, 4, 4]          --
│    │    └─Linear: 3-95                           [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-20        [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-96                        [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-97                           [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-98                      [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-99                        [128, 128, 4, 4]          --
│    │    └─Linear: 3-100                          [128, 128]                32,896
├─AvgPool2d: 1-12                                  [128, 128, 2, 2]          --
├─Sequential: 1-13                                 [128, 128, 2, 2]          --
│    └─ConditionedResidualConv2dLayer: 2-21        [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-101                       [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-102                          [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-103                     [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-104                       [128, 128, 2, 2]          --
│    │    └─Linear: 3-105                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-22        [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-106                       [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-107                          [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-108                     [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-109                       [128, 128, 2, 2]          --
│    │    └─Linear: 3-110                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-23        [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-111                       [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-112                          [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-113                     [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-114                       [128, 128, 2, 2]          --
│    │    └─Linear: 3-115                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-24        [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-116                       [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-117                          [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-118                     [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-119                       [128, 128, 2, 2]          --
│    │    └─Linear: 3-120                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-25        [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-121                       [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-122                          [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-123                     [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-124                       [128, 128, 2, 2]          --
│    │    └─Linear: 3-125                          [128, 128]                32,896
├─AvgPool2d: 1-14                                  [128, 128, 1, 1]          --
├─Flatten: 1-15                                    [128, 128]                --
├─Linear: 1-16                                     [128, 512]                66,048
├─Unflatten: 1-17                                  [128, 128, 2, 2]          --
├─Sequential: 1-18                                 [128, 128, 2, 2]          --
│    └─ConditionedResidualConv2dLayer: 2-26        [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-126                       [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-127                          [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-128                     [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-129                       [128, 128, 2, 2]          --
│    │    └─Linear: 3-130                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-27        [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-131                       [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-132                          [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-133                     [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-134                       [128, 128, 2, 2]          --
│    │    └─Linear: 3-135                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-28        [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-136                       [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-137                          [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-138                     [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-139                       [128, 128, 2, 2]          --
│    │    └─Linear: 3-140                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-29        [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-141                       [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-142                          [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-143                     [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-144                       [128, 128, 2, 2]          --
│    │    └─Linear: 3-145                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-30        [128, 128, 2, 2]          --
│    │    └─ZeroPad2d: 3-146                       [128, 128, 4, 4]          --
│    │    └─Conv2d: 3-147                          [128, 128, 2, 2]          147,584
│    │    └─BatchNorm2d: 3-148                     [128, 128, 2, 2]          256
│    │    └─LeakyReLU: 3-149                       [128, 128, 2, 2]          --
│    │    └─Linear: 3-150                          [128, 128]                32,896
├─Upsample: 1-19                                   [128, 128, 4, 4]          --
├─Sequential: 1-20                                 [128, 128, 4, 4]          --
│    └─ConditionedResidualConv2dLayer: 2-31        [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-151                       [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-152                          [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-153                     [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-154                       [128, 128, 4, 4]          --
│    │    └─Linear: 3-155                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-32        [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-156                       [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-157                          [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-158                     [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-159                       [128, 128, 4, 4]          --
│    │    └─Linear: 3-160                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-33        [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-161                       [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-162                          [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-163                     [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-164                       [128, 128, 4, 4]          --
│    │    └─Linear: 3-165                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-34        [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-166                       [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-167                          [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-168                     [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-169                       [128, 128, 4, 4]          --
│    │    └─Linear: 3-170                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-35        [128, 128, 4, 4]          --
│    │    └─ZeroPad2d: 3-171                       [128, 128, 6, 6]          --
│    │    └─Conv2d: 3-172                          [128, 128, 4, 4]          147,584
│    │    └─BatchNorm2d: 3-173                     [128, 128, 4, 4]          256
│    │    └─LeakyReLU: 3-174                       [128, 128, 4, 4]          --
│    │    └─Linear: 3-175                          [128, 128]                32,896
├─Upsample: 1-21                                   [128, 128, 8, 8]          --
├─Conv2d: 1-22                                     [128, 128, 7, 7]          65,664
├─Sequential: 1-23                                 [128, 128, 7, 7]          --
│    └─ConditionedResidualConv2dLayer: 2-36        [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-176                       [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-177                          [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-178                     [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-179                       [128, 128, 7, 7]          --
│    │    └─Linear: 3-180                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-37        [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-181                       [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-182                          [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-183                     [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-184                       [128, 128, 7, 7]          --
│    │    └─Linear: 3-185                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-38        [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-186                       [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-187                          [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-188                     [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-189                       [128, 128, 7, 7]          --
│    │    └─Linear: 3-190                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-39        [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-191                       [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-192                          [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-193                     [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-194                       [128, 128, 7, 7]          --
│    │    └─Linear: 3-195                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-40        [128, 128, 7, 7]          --
│    │    └─ZeroPad2d: 3-196                       [128, 128, 9, 9]          --
│    │    └─Conv2d: 3-197                          [128, 128, 7, 7]          147,584
│    │    └─BatchNorm2d: 3-198                     [128, 128, 7, 7]          256
│    │    └─LeakyReLU: 3-199                       [128, 128, 7, 7]          --
│    │    └─Linear: 3-200                          [128, 128]                32,896
├─Upsample: 1-24                                   [128, 128, 14, 14]        --
├─Sequential: 1-25                                 [128, 128, 14, 14]        --
│    └─ConditionedResidualConv2dLayer: 2-41        [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-201                       [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-202                          [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-203                     [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-204                       [128, 128, 14, 14]        --
│    │    └─Linear: 3-205                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-42        [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-206                       [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-207                          [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-208                     [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-209                       [128, 128, 14, 14]        --
│    │    └─Linear: 3-210                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-43        [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-211                       [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-212                          [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-213                     [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-214                       [128, 128, 14, 14]        --
│    │    └─Linear: 3-215                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-44        [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-216                       [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-217                          [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-218                     [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-219                       [128, 128, 14, 14]        --
│    │    └─Linear: 3-220                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-45        [128, 128, 14, 14]        --
│    │    └─ZeroPad2d: 3-221                       [128, 128, 16, 16]        --
│    │    └─Conv2d: 3-222                          [128, 128, 14, 14]        147,584
│    │    └─BatchNorm2d: 3-223                     [128, 128, 14, 14]        256
│    │    └─LeakyReLU: 3-224                       [128, 128, 14, 14]        --
│    │    └─Linear: 3-225                          [128, 128]                32,896
├─Upsample: 1-26                                   [128, 128, 28, 28]        --
├─Sequential: 1-27                                 [128, 128, 28, 28]        --
│    └─ConditionedResidualConv2dLayer: 2-46        [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-226                       [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-227                          [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-228                     [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-229                       [128, 128, 28, 28]        --
│    │    └─Linear: 3-230                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-47        [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-231                       [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-232                          [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-233                     [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-234                       [128, 128, 28, 28]        --
│    │    └─Linear: 3-235                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-48        [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-236                       [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-237                          [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-238                     [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-239                       [128, 128, 28, 28]        --
│    │    └─Linear: 3-240                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-49        [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-241                       [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-242                          [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-243                     [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-244                       [128, 128, 28, 28]        --
│    │    └─Linear: 3-245                          [128, 128]                32,896
│    └─ConditionedResidualConv2dLayer: 2-50        [128, 128, 28, 28]        --
│    │    └─ZeroPad2d: 3-246                       [128, 128, 30, 30]        --
│    │    └─Conv2d: 3-247                          [128, 128, 28, 28]        147,584
│    │    └─BatchNorm2d: 3-248                     [128, 128, 28, 28]        256
│    │    └─LeakyReLU: 3-249                       [128, 128, 28, 28]        --
│    │    └─Linear: 3-250                          [128, 128]                32,896
├─Conv2d: 1-28                                     [128, 1, 28, 28]          129
====================================================================================================
Total params: 9,171,457
Trainable params: 9,171,457
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 198.84
====================================================================================================
Input size (MB): 0.40
Forward/backward pass size (MB): 2867.22
Params size (MB): 36.69
Estimated Total Size (MB): 2904.31
====================================================================================================
In [11]:
# As with the DDPM above, the graph is too complex to render as a reasonable PNG...

# model_graph = draw_graph(temp_u,
#                          input_size=[(data_module.batch_size,)+data_module.input_shape,
#                                      (data_module.batch_size,),
#                                      (data_module.batch_size,)],
#                          dtypes=[torch.float32,torch.long,torch.long],
#                          depth=3)
# model_graph.visual_graph
In [12]:
class ConditionedDiffusionModel(pl.LightningModule):
    def __init__(
        self,
        latent_dim: int = 128,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = 256,
        noise_steps=200,
        beta_start=1e-4,
        beta_end=0.02,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.beta = self.prepare_noise_schedule()
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)
        
        # networks
        self.unet = ConditionedUNet(latent_dim=self.hparams.latent_dim,
                                    c_dim=self.hparams.latent_dim)
        # loss
        # self.my_loss = torch.nn.MSELoss()
        self.my_loss = torchmetrics.LogCoshError(num_outputs=np.prod([1,28,28]))
        # ema
        # self.ema = torch_ema.ExponentialMovingAverage(self.unet.parameters(), decay=0.995)
        
    def forward(self, x, c, t):
        return self.unet(x, c, t)

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2
        return torch.optim.AdamW(self.unet.parameters(), lr=lr, betas=(b1, b2))

    def predict_step(self, batch):
        # Reverse diffusion: start from pure noise and iteratively denoise,
        # conditioning each step on the class labels y_true.
        x, y_true = batch
        data = []
        x_t = torch.randn_like(x).type_as(x)
        data += [x_t]
        for i in reversed(range(1,self.noise_steps)):
            t = (torch.ones(x.shape[0])*i).type_as(y_true)
            # with self.ema.average_parameters():
            predicted_noise = self(x_t, y_true, t.type_as(x_t))
            alpha = self.alpha.type_as(x)[t][:, None, None, None]
            alpha_hat = self.alpha_hat.type_as(x)[t][:, None, None, None]
            beta = self.beta.type_as(x)[t][:, None, None, None]
            if i > 1:
                noise = torch.randn_like(x).type_as(x)
            else:
                noise = torch.zeros_like(x).type_as(x)
            x_t = 1 / torch.sqrt(alpha) * (x_t - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
            data += [x_t]
        # Stack the saved trajectory, move channels last, and rescale from [-1,1] to [0,1]
        return (torch.clip(torch.permute(torch.stack(data),(0,1,3,4,2)),-1,1)+1.0)/2.0
        
    def training_step(self, batch):
        x, y_true = batch
        t = self.sample_timesteps(x.shape[0]).type_as(y_true)
        x_t, noise = self.noise_images(x,t)
        predicted_noise = self(x_t, y_true, t.type_as(x_t))
        # loss = self.my_loss(predicted_noise,
        #                     noise)
        loss = torch.mean(self.my_loss(torch.flatten(predicted_noise,start_dim=1),
                                       torch.flatten(noise,start_dim=1)))
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        return loss

    # def on_before_zero_grad(self, *args, **kwargs):
    #     self.ema.update(self.unet.parameters())

    def prepare_noise_schedule(self):
        # Linear beta schedule from beta_start to beta_end over noise_steps
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)
        
    def sample_timesteps(self, n):
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def noise_images(self, x, t):
        # Closed-form forward process: x_t = sqrt(alpha_hat_t)*x_0 + sqrt(1-alpha_hat_t)*eps
        alpha_hat = self.alpha_hat.type_as(x)
        sqrt_alpha_hat = torch.sqrt(alpha_hat[t])[:, None, None, None].type_as(x)
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - alpha_hat[t])[:, None, None, None].type_as(x)
        e = torch.randn_like(x).type_as(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * e, e
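The noise schedule and forward process above can be checked in isolation. This is a standalone sketch (not part of the notebook's model) that rebuilds the same linear beta schedule and applies the closed-form noising from `noise_images` to a stand-in batch; the batch shape and timesteps are arbitrary choices for illustration:

```python
import torch

# Same schedule hyperparameters as ConditionedDiffusionModel's defaults
noise_steps, beta_start, beta_end = 200, 1e-4, 0.02
beta = torch.linspace(beta_start, beta_end, noise_steps)
alpha_hat = torch.cumprod(1.0 - beta, dim=0)  # cumulative signal coefficient

x0 = torch.randn(4, 1, 28, 28)       # stand-in "image" batch
t = torch.tensor([0, 50, 150, 199])  # one timestep per sample
eps = torch.randn_like(x0)

# x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * eps
x_t = (torch.sqrt(alpha_hat[t])[:, None, None, None] * x0
       + torch.sqrt(1 - alpha_hat[t])[:, None, None, None] * eps)

# alpha_hat decays monotonically, so later timesteps keep less of the signal
print(alpha_hat[0].item(), alpha_hat[-1].item())
```

Note that with only 200 steps and this schedule, `alpha_hat` at the final step is still noticeably above zero, which is one reason the final samples retain some structure from the schedule itself.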
In [13]:
model = ConditionedDiffusionModel()
In [14]:
# model = ConditionedDiffusionModel.load_from_checkpoint("logs/conditioned-diffusion/mnist/checkpoints/epoch=325-step=152894.ckpt")
In [15]:
logger = pl.loggers.CSVLogger("logs",
                              name="conditioned-diffusion",
                              version="mnist")
In [16]:
trainer = pl.Trainer(logger=logger,
                     max_epochs=500,
                     enable_progress_bar=True,
                     log_every_n_steps=0,
                     callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=50)]
                    )
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

Before Training...¶

In [17]:
result = trainer.predict(model,data_module)[0]
all_imgs = result.cpu().detach().numpy()
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
/opt/conda/lib/python3.12/site-packages/torch/utils/data/dataloader.py:617: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(
Predicting: |          | 0/? [00:00<?, ?it/s]
In [18]:
all_imgs.shape
Out[18]:
(200, 80, 28, 28, 1)
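Reading this shape: axis 0 is the denoising trajectory saved by `predict_step` (the initial noise plus one snapshot per reverse step), axis 1 is the predict batch, and the rest is height, width, channels. A small sketch of how to slice it, using a random stand-in array of the same shape:

```python
import numpy as np

# Stand-in for the (steps, batch, H, W, C) predict output
all_imgs = np.random.rand(200, 80, 28, 28, 1)

snapshots = all_imgs[::50, 0]  # every 50th denoising step of sample 0
final = all_imgs[-1]           # fully denoised batch (what gets plotted below)
print(snapshots.shape, final.shape)
```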
In [19]:
imgs = all_imgs[-1]
In [20]:
fig = plt.figure(figsize=(10, 8))
for i in range(imgs.shape[0]):
    plt.subplot(10, 8, i+1)
    plt.imshow(imgs[i], cmap="gray")
    plt.axis("off")
plt.show()
(Figure: 10×8 grid of generated samples from the untrained model)
In [26]:
# Training Step
trainer.fit(model, data_module)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]

  | Name    | Type            | Params
--------------------------------------------
0 | unet    | ConditionedUNet | 8.4 M 
1 | my_loss | LogCoshError    | 0     
--------------------------------------------
8.4 M     Trainable params
0         Non-trainable params
8.4 M     Total params
33.404    Total estimated model params size (MB)
Training: |          | 0/? [00:00<?, ?it/s]
/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
In [39]:
results = pd.read_csv("logs/conditioned-diffusion/mnist/metrics.csv",sep=",")
In [41]:
plt.plot(results["epoch"][np.logical_not(np.isnan(results["train_loss"]))],
         results["train_loss"][np.logical_not(np.isnan(results["train_loss"]))],
         label="Loss")
plt.ylabel('MeanLogCosh Loss')
plt.xlabel('Epoch')
plt.ylim([0.0,0.1])
# plt.legend(loc='lower right')
plt.show()
(Figure: training loss vs. epoch)

After Training...¶

(Training was interrupted early, so the model is not fully converged...)

In [22]:
result = trainer.predict(model,data_module)[0]
all_imgs = result.cpu().detach().numpy()
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.
Predicting: |          | 0/? [00:00<?, ?it/s]
In [23]:
all_imgs.shape
Out[23]:
(200, 80, 28, 28, 1)
In [24]:
imgs = all_imgs[-1]
In [46]:
fig = plt.figure(figsize=(10, 8))
for i in range(imgs.shape[0]):
    plt.subplot(10, 8, i+1)
    plt.imshow(imgs[i], cmap="gray")
    plt.axis("off")
plt.show()
(Figure: 10×8 grid of generated MNIST samples after training)
In [ ]: